{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.manifold import TSNE\n",
    "from collections import OrderedDict\n",
    "\n",
    "# load other modules --> repo root path\n",
    "sys.path.insert(0, \"../\")\n",
    "\n",
    "import torch\n",
    "from utils import text, audio\n",
    "from utils import build_model\n",
    "from params.params import Params as hp\n",
    "from modules.tacotron2 import Tacotron"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode(model, inputs):\n",
    "    \n",
    "    inputs = [l.rstrip().split('|') for l in inputs if l]\n",
    "    encodeds = []\n",
    "\n",
    "    for i in inputs:\n",
    "        t = torch.LongTensor(text.to_sequence(i[0], use_phonemes=hp.use_phonemes))\n",
    "        l = torch.LongTensor([hp.languages.index(i[2])]) if hp.multi_language else None\n",
    "\n",
    "        if torch.cuda.is_available(): \n",
    "            t = t.cuda(non_blocking=True)\n",
    "            if l: l = l.cuda(non_blocking=True)\n",
    "            if s: s = s.cuda(non_blocking=True)\n",
    "\n",
    "        t.unsqueeze_(0)\n",
    "        embedded = model._embedding(t)\n",
    "        encoded = model._encoder(embedded, torch.LongTensor([t.size(1)]), l)\n",
    "        \n",
    "        unique_chars = list(set(i[0]))\n",
    "        char_ids = [unique_chars.index(x) for x in i[0]]\n",
    "        encodeds.append((i[0], char_ids, encoded.squeeze(0).cpu().detach().numpy()[:-1, :]))\n",
    "    \n",
    "    return encodeds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint = \"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.load(checkpoint, map_location=\"cpu\")['parameters']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "model = build_model(checkpoint)\n",
    "model.eval();\n",
    "print(hp.encoder_type)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Encoded output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = [\"erlauben sie bitte, dass ich mich kurz vorstelle. ich heiße jana novakova.||fr\",\n",
    "          \"erlauben sie bitte, dass ich mich kurz vorstelle. ich heiße jana novakova.||de\",\n",
    "          \"les socialistes et les républicains sont venus apporter leurs voix à la majorité pour ce texte.||fr\",\n",
    "          \"les socialistes et les républicains sont venus apporter leurs voix à la majorité pour ce texte.||de\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings = encode(model, inputs[:-2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tsne = [(t, c, TSNE(n_components=2).fit_transform(e)) for (t, c, e) in embeddings]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(19, 25))\n",
    "for i, (t, c, e) in enumerate(tsne):\n",
    "    ax = plt.subplot(4, 3, i + 1)   \n",
    "    for j in range(len(t)):\n",
    "        plt.scatter(e[j, 0], e[j, 1], c='k', marker=r\"$ {} $\".format(t[j].replace(\" \", \"/\")), alpha=0.7, s=50)\n",
    "        #ax.set_title(f'pool_layer={i + 1}')\n",
    "plt.tight_layout()\n",
    "plt.subplots_adjust(bottom=0.1, right=0.95, top=0.9)\n",
    "#cax = plt.axes([0.96, 0.6, 0.02, 0.3])\n",
    "#cbar = plt.colorbar(cax=cax, ticks=range(len(texts)))\n",
    "\n",
    "#cbar.ax.get_yaxis().set_ticks([])\n",
    "#for j, lab in enumerate(texts.keys()):\n",
    "#    cbar.ax.text(4.25, (2 * j + 1) / 2.25, lab, ha='center', va='center', fontsize=15)\n",
    "\n",
    "plt.show() "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Language embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_similarity(similarity, languages):\n",
    "    fig = plt.figure(figsize=(6, 6))\n",
    "    ax = fig.add_subplot(111)\n",
    "    diagonal = similarity[1:,:-1].copy()\n",
    "    lower_indices = np.tril_indices(diagonal.shape[0])\n",
    "    lower = diagonal[lower_indices]\n",
    "    lower_min = np.min(lower)\n",
    "    lower_max = np.max(lower)\n",
    "    diagonal = (diagonal - lower_min) / (lower_max - lower_min)\n",
    "    cax = ax.matshow(np.tril(diagonal), interpolation='nearest')\n",
    "    fig.colorbar(cax)\n",
    "    ax.set_xticklabels([languages[0]]+languages[:-1], rotation='vertical')\n",
    "    ax.set_yticklabels(languages)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings = model._encoder._embedding.weight.detach().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "similarity = cosine_similarity(embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_similarity(similarity, hp.languages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "layer_number = 2\n",
    "layer_weights = model._encoder._layers[layer_number]._convolution._bottleneck.weight.detach().numpy()\n",
    "bottleneck_embeddings = embeddings @ layer_weights.T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bottleneck_similarity = cosine_similarity(bottleneck_embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_similarity(bottleneck_similarity, hp.languages)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
